package face5wap;

import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;

/**
 * Notes:
 * Blog: https://blog.csdn.net/dong_lxkm/article/details/103125742
 *     1. DL4J does not provide a Keras-style way to freeze layer parameters, so here the
 *        learning rate of the layers to be frozen is set to 0 instead.
 *
 *     2. The top-level updater must be SGD, not an adaptive one (e.g. Adam, RMSProp),
 *        because adaptive updaters fold gradients from previous batches into the current
 *        update, which would defeat the purpose of freezing parameters.
 *
 *     3. A StackVertex is used to stack tensors along the first dimension, i.e. to merge
 *        real samples with samples produced by the Generator so that the Discriminator
 *        can be trained on both together.
 *
 *     4. During training the Discriminator parameters are updated several times (to
 *        measure the maximum distance) before the Generator is updated once.
 *
 *     5. 100,000 iterations are performed.
 *     6. The data is downloaded to C:\Users\Administrator\.deeplearning4j\data\MNIST
 */
public class OutGan {
	static int height = 28; // input image height
	static int width = 28; // input image width
	static int channels = 3; // input image channel count
	// Declared as {channels, width, height}; consumed below as
	// convolutionalFlat(inputShape[2], inputShape[1], inputShape[0]) = (height, width, channels).
	static int[] inputShape = new int[] {channels, width, height};
	public static final int batch = 8;
	static double lr = 0.01;
	static String model = "F:/face/gan.zip"; // path used to persist the trained model

	/**
	 * Builds the network graph, streams training images from disk, and runs up to
	 * 100,000 fit-less iterations: every 10th iteration random noise is forwarded
	 * through the net and visualized, and every 10,000th iteration the model is saved.
	 */
	public static void main(String[] args) throws Exception {

		final GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER)
				.graphBuilder()
				.backpropType(BackpropType.Standard)
				.addInputs("input1")
				.setInputTypes(InputType.convolutionalFlat(inputShape[2], inputShape[1], inputShape[0]))
				// RmsProp with learning rate 0 effectively freezes these layers' parameters
				// (see the class-level notes on freezing via a zero learning rate).
				.addLayer("g1", new BatchNormalization.Builder()
						.updater(new RmsProp(0, 1e-8, 1e-8))
						.build(), "input1")
				.addLayer("g2", new DenseLayer.Builder()
						.updater(new RmsProp(0, 1e-8, 1e-8))
						.nOut(1024)
						.build(), "g1")
				.addLayer("g3",
						new DenseLayer.Builder().nOut(4 * 3 * 28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g2")
				// g3 emits 4 * 3 * 28 * 28 = 3 * 56 * 56 values, reshaped to a (3, 56, 56) volume.
				.inputPreProcessor("gen_reshape", new FeedForwardToCnnPreProcessor(28 * 2, 28 * 2, 3))
				.addLayer("gen_reshape", new Upsampling2D.Builder(2)
						.build(), "g3")
				// Convolution output size: valid -> (w - f + 1) / s, same -> w / s.
				.addLayer("out", new ConvolutionLayer.Builder(5, 5)
						.stride(1, 1)
						.convolutionMode(ConvolutionMode.Same)
						.activation(Activation.SIGMOID)
						.updater(new RmsProp(0, 1e-8, 1e-8))
						.nIn(3)
						.nOut(3)
						.build(), "gen_reshape")
				.setOutputs("out");

		// Build the graph exactly once (the original code constructed a second, identical
		// ComputationGraph and discarded the first instance).
		ComputationGraph net = new ComputationGraph(graphBuilder.build());
		// To resume from a previously saved model instead, use:
		// if (new File(model).exists()) { net = ComputationGraph.load(new File(model), true); }
		net.init();
		System.out.println(net.summary());

		// Attach the DL4J training UI and a periodic score logger.
		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		net.setListeners(new ScoreIterationListener(100));

		// Read training images from disk; the parent directory name is the label.
		String inputDataDir = "E:" + "/face/yzm";
		File trainDataFile = new File(inputDataDir + "/B");
		FileSplit trainSplit = new FileSplit(trainDataFile, NativeImageLoader.ALLOWED_FORMATS);
		ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // parent path as the image label
		ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
		trainRR.initialize(trainSplit);
		DataSetIterator train = new RecordReaderDataSetIterator(trainRR, batch, 1, 1);
		// Scale pixel values into [0, 1].
		DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
		scaler.fit(train);
		train.setPreProcessor(scaler);
		train.reset();

		INDArray labelD = Nd4j.zeros(batch, channels);
		// NOTE(review): labelG is unused while the trainD/trainG calls below stay commented out.
		INDArray labelG = Nd4j.ones(60, 1);

		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset(); // restart the epoch once the iterator is exhausted
			}
			INDArray trueExp = train.next().getFeatures();

			// Random noise with the same NCHW shape as the real samples.
			INDArray z = Nd4j.rand(new NormalDistribution(), new long[] { batch, channels, height, width });
			MultiDataSet dataSetD = new MultiDataSet(new INDArray[] { trueExp }, new INDArray[] { labelD });
			/*for(int m=0;m<10;m++){
				trainD(net, dataSetD);
			}
			z = Nd4j.rand(new NormalDistribution(),new long[] { 30, 10 });
			MultiDataSet dataSetG = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelG });*/
			//trainG(net, dataSetD);

			if (i % 10 == 0) {
				// Forward the noise through the net and display the generated images.
				// (The original copied the output array element-by-element into an
				// identically-typed array, with a redundant (int) cast on the length.)
				INDArray[] samples = net.output(Nd4j.vstack(z));
				visualize(samples);
			}

			if (i % 10000 == 0) {
				net.save(new File(model), true);
			}
		}
	}

	/**
	 * Fits one batch with the generator layers (g1-g3) frozen so that only the
	 * discriminator side learns — D(x).
	 *
	 * NOTE(review): layers "d1", "d2" and "d3" are not defined in the graph built by
	 * main(), so calling this against that graph would fail on setLearningRate. The
	 * method is currently unreferenced (its call sites are commented out); confirm the
	 * intended discriminator layer names before re-enabling it.
	 */
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", 0);
		net.setLearningRate("g2", 0);
		net.setLearningRate("g3", 0);
		net.setLearningRate("d1", lr);
		net.setLearningRate("d2", lr);
		net.setLearningRate("d3", lr);
		net.setLearningRate("out", lr);
		net.fit(dataSet);
	}

	/**
	 * Fits one batch with the output layer frozen so that only the generator layers
	 * learn — G(z). See the NOTE on {@link #trainD} about the commented-out "d*" layers;
	 * this method is likewise currently unreferenced.
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", lr);
		net.setLearningRate("g2", lr);
		net.setLearningRate("g3", lr);
		/*net.setLearningRate("d1", 0);
		net.setLearningRate("d2", 0);
		net.setLearningRate("d3", 0);*/
		net.setLearningRate("out", 0);
		net.fit(dataSet);
	}

	// Lazily-created Swing window and grid panel reused across visualize() calls.
	private static JFrame frame;
	private static JPanel panel;

	/**
	 * Renders each sample tensor as a grayscale image in a (lazily created) Swing grid.
	 * Null or empty tensors are skipped.
	 */
	private static void visualize(INDArray[] samples) {
		if (frame == null) {
			frame = new JFrame();
			frame.setTitle("Viz");
			frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
			frame.setLayout(new BorderLayout());

			panel = new JPanel();

			panel.setLayout(new GridLayout(samples.length, 4, 8, 8));
			frame.add(panel, BorderLayout.CENTER);
			frame.setVisible(true);
		}

		panel.removeAll();

		for (INDArray sample : samples) {
			if (sample == null || sample.size(0) == 0) {
				continue;
			}
			panel.add(imageFromINDArray(sample));
		}

		frame.revalidate();
		frame.pack();
	}

	/**
	 * Converts a flattened 28x28 tensor with values in [0, 1] into a 9x-scaled grayscale
	 * JLabel. NOTE(review): currently unused — visualize() calls imageFromINDArray instead.
	 */
	private static JLabel getImage(INDArray tensor) {
		BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
		for (int i = 0; i < 784; i++) {
			// Row-major walk over the flat tensor; scale [0, 1] to [0, 255].
			bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * tensor.getDouble(i)));
		}
		ImageIcon orig = new ImageIcon(bi);
		Image imageScaled = orig.getImage().getScaledInstance((int) (9 * 28), (int) (9 * 28),
				Image.SCALE_DEFAULT);
		ImageIcon scaled = new ImageIcon(imageScaled);

		return new JLabel(scaled);
	}

	/**
	 * Converts the first channel of the first sample of an NCHW tensor into a 9x-scaled
	 * grayscale JLabel.
	 *
	 * NOTE(review): the (v + 1) * 127.5 mapping assumes values in [-1, 1] (tanh-style),
	 * but the graph's output layer uses SIGMOID ([0, 1]) — confirm the intended range.
	 */
	private static JLabel imageFromINDArray(INDArray array) {
		long[] shape = array.shape();
		int height = (int) shape[2];
		int width = (int) shape[3];
		BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
		for (int x = 0; x < width; x++) {
			for (int y = 0; y < height; y++) {
				int gray = (int) ((array.getDouble(0, 0, y, x) + 1) * 127.5);

				// Clamp out-of-range pixel values to [0, 255].
				gray = Math.min(gray, 255);
				gray = Math.max(gray, 0);

				image.getRaster().setSample(x, y, 0, gray);
			}
		}
		ImageIcon orig = new ImageIcon(image);
		Image imageScaled = orig.getImage().getScaledInstance((int) (9 * 28), (int) (9 * 28),
				Image.SCALE_DEFAULT);
		ImageIcon scaled = new ImageIcon(imageScaled);

		return new JLabel(scaled);
	}
}